Mateusz Szymański
This homework focuses on LIME: Local Interpretable Model-agnostic Explanations. The goal of the task is to apply this XAI method to models trained in the previous Homeworks, and then compare the results with SHAP: SHapley Additive exPlanations.
Let us recall that we have selected four vanilla models:

- LogisticRegression from sklearn
- RandomForestClassifier from sklearn
- XGBClassifier from xgboost
- TabPFN from tabpfn

The dataset phoneme.csv consists of 5 unnamed columns describing amplitudes of certain harmonics. The target variable is a binary classification of vowels: nasal or oral.
We are going to examine the models' behavior for selected observations. Of course, we can select some observations at random, and we intend to do so: we've collected 250 random samples out of all observations.
Besides, we are going to do a little better than that, by finding observations that:

- are misclassified by most of the models,
- cause the highest disagreement among the models.

The first condition means that the following error:
$$\sum_{m\in M}\left|m(x)-y\right|^2$$

is the highest, where $M$ is the set of all models, $m(x)$ is the true-class probability assigned by model $m$ to observation $x$, and $y$ is the true value. The higher the value, the more models misclassify the observation.
The second condition describes the variance among predictions:
$$\sum_{m\in M}\left|m(x)-\overline{m}(x)\right|^2$$

where $\overline{m}(x)$ is the average of the probabilities assigned by the models.
By this method we gather 15 sample observations by taking:

- one third at random,
- one third with the highest variance,
- one third with the highest error.

One third of the chosen observations belongs to the second class, which is close to the imbalance coefficient of the entire dataset (about $30\%$).
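The two selection criteria can be sketched with plain numpy. This is a minimal toy example (the prediction matrix below is a made-up stand-in, not our actual model outputs):

```python
import numpy as np

# Toy stand-in: rows = observations, columns = per-model probabilities m(x).
preds = np.array([
    [0.10, 0.97, 0.99, 0.98],  # models disagree strongly
    [0.02, 0.01, 0.00, 0.00],  # models agree, but all are wrong if y = 1
    [0.90, 0.95, 0.99, 0.97],  # models agree and are right if y = 1
])
y = np.array([1, 1, 1])

# First criterion: total squared error sum_m |m(x) - y|^2.
error = ((preds - y[:, None]) ** 2).sum(axis=1)
# Second criterion: spread of predictions around the model average.
variance = ((preds - preds.mean(axis=1, keepdims=True)) ** 2).sum(axis=1)

top_error = np.argsort(error)[::-1]        # most misclassified first
top_variance = np.argsort(variance)[::-1]  # most disputed first
```

The second row maximizes the error criterion (all models confidently wrong), while the first row maximizes the variance criterion (models split between the classes).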
As LIME explanations are random in nature and explanations for the same observation can differ significantly, we decided to average the results by setting 1000 samples for each observation (num_samples=1000).
As a benchmark, to quantify the aforementioned differences, we calculated the average standard deviation of explanation values given by LIME for logistic regression on 250 randomly chosen elements of our dataset.
| column | standard deviation |
|---|---|
| V1 | 0.010810 |
| V2 | 0.011232 |
| V3 | 0.011175 |
| V4 | 0.010665 |
| V5 | 0.011582 |
The standard deviation of around $0.01$ for each column is not negligible but tolerable.
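The stability estimate boils down to a per-column standard deviation over repeated explanations of the same observation. A minimal sketch with synthetic numbers standing in for repeated LIME runs (in the real computation the rows come from predict_surrogate):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for 25 re-explanations of one observation across 5 columns:
# effects near 0.1 with noise of scale 0.01, mimicking LIME's randomness.
runs = 0.1 + 0.01 * rng.standard_normal((25, 5))

per_column_std = runs.std(axis=0)  # spread of each column's effect
print(per_column_std.round(4))
```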
Let us proceed to the results.
As the trained logistic regression model is fully defined by its coefficients (and the bias), we expect the values given by LIME to stay in some relation to the coefficients of the model.
Let us compare the logistic regression coefficients with the values given by LIME. To estimate the relationship we took a sample of 250 elements, since our initial 15 elements are certainly biased.
| column | logistic regression coefficient | LIME mean |
|---|---|---|
| V1 | -0.6998 | 0.1100 |
| V2 | -0.4370 | 0.0815 |
| V3 | 0.4746 | 0.0900 |
| V4 | 0.6726 | 0.1269 |
| V5 | 0.1958 | 0.0361 |
Even with a small sample we see that there is some relationship between the absolute values of the logistic regression coefficients and the values given by LIME, especially for V4 and V5. The last column is in general the weakest one, while the penultimate one is often the strongest.
The correlation coefficient for the absolute value of the coefficients and LIME is very high: $97\%$.
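The correlation can be reproduced directly from the table above (values rounded as printed there):

```python
import numpy as np

# Values copied from the comparison table above.
coef = np.array([-0.6998, -0.4370, 0.4746, 0.6726, 0.1958])
lime_mean = np.array([0.1100, 0.0815, 0.0900, 0.1269, 0.0361])

# Pearson correlation between |coefficient| and the mean LIME value.
corr = np.corrcoef(np.abs(coef), lime_mean)[0, 1]
print(f'{100 * corr:.2f}%')  # ~97%
```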
LIME values differ among randomly chosen samples, but they are similar for columns in the same value range. This is clearly visible in the following example:
| 3174 | 3507 |
|---|---|
Two observations share similar LIME values except for the second column.
We have used LIME to explain the logistic regression model. We did the same for the random forest model. We can compare average LIME values with the random forest feature importance values (which sum up to $1$). This comparison may not be fully justified, as feature importance describes a slightly different thing.
| column | feature importance | LIME mean |
|---|---|---|
| V1 | 0.1992 | 0.0745 |
| V2 | 0.1622 | 0.0788 |
| V3 | 0.2087 | 0.0793 |
| V4 | 0.3013 | 0.1166 |
| V5 | 0.1286 | 0.0563 |
Nevertheless, the correlation coefficient between LIME values and feature importances is high, around $95\%$.
LIME values are still close to each other for the same ranges of values, for each column. Let us take the same observations as in the previous example:
| 3174 | 3507 |
|---|---|
Again, except for the second column, the values are more or less the same.
We repeat the comparison for XGBClassifier, using its feature importances:

| column | feature importance | LIME mean |
|---|---|---|
| V1 | 0.1662 | 0.1476 |
| V2 | 0.1427 | 0.0990 |
| V3 | 0.1517 | 0.0627 |
| V4 | 0.4055 | 0.1503 |
| V5 | 0.1340 | 0.0491 |
This time the correlation is still high, but much lower than for the previous models, namely $64.5\%$. We see that XGBoost treats the V4 column as much more important than the other ones, which are of similar magnitude.
Some discrepancies for the same value ranges have been observed, but in general the LIME values are stable.
| 416 | 3022 |
|---|---|
There is a visible difference for the V4 column in this example.
Due to the very slow TabPFN inference time, the number of samples (num_samples) has been dramatically reduced to 5.
We can see that, contrary to the previous models, TabPFN no longer has similar LIME values for the same column ranges. This is easily seen in this example:
| 3507 | 3174 |
|---|---|
Despite the fact that almost all values lie in the same range (except for V2), the LIME values differ, especially for V4.
However, this effect may be a result of the reduced number of samples per observation.
SHapley Additive exPlanations (SHAP) are based on Shapley values from game theory. The role of Shapley values is to determine the marginal contribution of each agent. In this setting, this translates to the contribution of a column to the value predicted by a classifier.
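To make the game-theoretic idea concrete, here is a toy two-"feature" cooperative game (the value function below is invented for illustration): a player's Shapley value is its marginal contribution averaged over all orders in which players can join the coalition.

```python
from itertools import permutations

# Toy value function: maps a coalition of present features to a model output.
v = {
    frozenset(): 0.0,
    frozenset({'V1'}): 0.3,
    frozenset({'V2'}): 0.1,
    frozenset({'V1', 'V2'}): 0.5,
}

def shapley(player, players=('V1', 'V2')):
    # Average the player's marginal contribution over all join orders.
    orders = list(permutations(players))
    total = 0.0
    for order in orders:
        before = frozenset(order[:order.index(player)])
        total += v[before | {player}] - v[before]
    return total / len(orders)

print(shapley('V1'), shapley('V2'))
```

A useful sanity check is the efficiency property: the Shapley values sum to the value of the grand coalition, $0.35 + 0.15 = 0.5$.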
Note: TabPFN has been disabled due to very slow inference time.
All models except TabPFN have been analyzed using shap package.
| model | shap |
|---|---|
| LogisticRegression | |
| RandomForestClassifier | |
| XGBClassifier | |
Both LogisticRegression and RandomForestClassifier share the same order of columns. In each case V4 is the most important one while V5 is the least important.
For logistic regression, both SHAP and LIME values often agree to some extent. Let us see some examples (caution: SHAP columns are sorted):
| id | LIME | SHAP |
|---|---|---|
| 1299 |
We see that both V2 and V3 are positive while V4 and V5 are negative. It is not an exact correspondence but these explanations are close to each other.
| id | LIME | SHAP |
|---|---|---|
| 2392 |
There are, of course, more examples of that sort.
Marginal contribution works a bit differently for the random forest model. This leads to disagreements between LIME and SHAP, as illustrated below:
| id | LIME | SHAP |
|---|---|---|
| 803 |
Both methods disagree on what the strongest contributor is and, in the case of V2, V4 and V5, even on the direction of the contribution. XGBClassifier has many similar examples.
SHAP values describe the marginal contribution of each component (over different combinations of contributors). LIME works in a different way: it estimates the impact of each feature in a local neighborhood of an observation, by perturbing the observation's values and comparing model predictions for the perturbed values.
These values are not directly comparable, but under certain conditions they may agree to some extent, as the logistic regression model shows.
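The perturbation idea behind LIME can be sketched with numpy alone: sample points around the observation, weight them by proximity, and fit a weighted linear surrogate. The black-box model below is a hypothetical stand-in, not one of our trained models:

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical black-box probability model (coefficients are illustrative).
def black_box(X):
    return 1 / (1 + np.exp(-(1.5 * X[:, 0] - 0.5 * X[:, 1])))

x0 = np.array([0.2, -0.4])                       # observation to explain
Z = x0 + 0.5 * rng.standard_normal((1000, 2))    # local perturbations
w = np.exp(-np.sum((Z - x0) ** 2, axis=1))       # proximity kernel weights

# Weighted least squares: fit a local linear surrogate to the black box.
A = np.hstack([Z, np.ones((len(Z), 1))]) * np.sqrt(w)[:, None]
b = black_box(Z) * np.sqrt(w)
coef, *_ = np.linalg.lstsq(A, b, rcond=None)
local_effects = coef[:2]                         # per-feature local slopes
```

The fitted slopes recover the local signs and relative magnitudes of the black box: positive for the first feature, negative and weaker for the second.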
We've selected some observations. Now let us use LIME values to understand why RandomForestClassifier made a mistake on an observation correctly handled by LogisticRegression.
| V1 | V2 | V3 | V4 | V5 | TARGET | LogisticRegression | RandomForestClassifier | XGBClassifier | TabPFNClassifier | variance | error | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1552 | -0.55 | 0.46 | 1.62 | -1.27 | -0.14 | 1.00 | 0.10 | 0.97 | 1.00 | 0.98 | 0.15 | 2.91 |
| LogisticRegression | RandomForestClassifier |
|---|---|
Both models agree on all columns except the decisive one, V4, which happens to be the strongest variable.
Thanks to this information we could:
- find observations similar to #1552 and see how they are classified,
- analyze RandomForestClassifier behavior in this region and see how well it works,
- investigate the V4 variable in the neighborhood of #1552.

This is beyond the scope of the homework, so we leave it at that.
import random
import shap
import dalex as dx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from tabpfn import TabPFNClassifier
from tqdm.notebook import tqdm
RANDOM_STATE = 42
NUMBER_OF_OBSERVATIONS = 15
SAMPLE_SIZE = 250
NUM_COLUMNS = 5
NUM_SAMPLES = 1000
FIGSIZE = (6, 3)
def seed(random_state=RANDOM_STATE):
np.random.seed(random_state)
random.seed(random_state)
seed()
df = pd.read_csv('datasets/phoneme.csv').iloc[:, 1:]
X = df.iloc[:, :-1]
y = df.iloc[:, -1].to_numpy() - 1
assert tuple(sorted(np.unique(y))) == (0, 1)
print(f'Dataset size: {len(y)} observations.')
Dataset size: 5404 observations.
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.75, random_state=RANDOM_STATE
)
sample_indices = np.random.choice(df.index, SAMPLE_SIZE)  # note: with replacement (np.random.choice default), so duplicates are possible
logistic_regression = LogisticRegression(random_state=RANDOM_STATE)
logistic_regression.fit(X_train, y_train)
random_forest = RandomForestClassifier(random_state=RANDOM_STATE)
random_forest.fit(X_train, y_train)
xgb = XGBClassifier(random_state=RANDOM_STATE)
xgb.fit(X_train, y_train)
tab_pfn = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)
tab_pfn.fit(X_train, y_train, overwrite_warning=True)
Loading model that can be used for inference only Using a Transformer with 25.82 M parameters
def get_proba_and_errors(model, X_test=X_test, y_test=y_test):
y_pred = model.predict_proba(X_test)[:, 1]
return y_pred, (y_pred - y_test) ** 2
logistic_regression_pred, logistic_regression_errors = get_proba_and_errors(logistic_regression)
random_forest_pred, random_forest_errors = get_proba_and_errors(random_forest)
xgb_pred, xgb_errors = get_proba_and_errors(xgb)
tab_pfn_pred, tab_pfn_errors = get_proba_and_errors(tab_pfn)
df_pred = pd.DataFrame([logistic_regression_pred, random_forest_pred, xgb_pred, tab_pfn_pred]).T
df_errors = pd.DataFrame([logistic_regression_errors, random_forest_errors, xgb_errors, tab_pfn_errors]).T
df_pred.columns = ['LogisticRegression', 'RandomForestClassifier', 'XGBClassifier', 'TabPFNClassifier']
df_errors.columns = ['LogisticRegression', 'RandomForestClassifier', 'XGBClassifier', 'TabPFNClassifier']
variance = np.var(df_pred, axis=1)
total = np.sum(df_errors, axis=1)
df_pred['variance'] = variance
df_errors['error'] = total
seed()
indices = list(np.random.choice(df_errors.index, NUMBER_OF_OBSERVATIONS // 3))
df_pred_top = df_pred.sort_values('variance', ascending=False).head(NUMBER_OF_OBSERVATIONS // 3)
indices.extend(df_pred_top.index)
display(df_pred_top)
| LogisticRegression | RandomForestClassifier | XGBClassifier | TabPFNClassifier | variance | |
|---|---|---|---|---|---|
| 1552 | 0.096915 | 0.97 | 0.997193 | 0.981797 | 0.147307 |
| 2392 | 0.100094 | 0.97 | 0.991902 | 0.965859 | 0.143924 |
| 157 | 0.112273 | 0.91 | 0.995893 | 0.980029 | 0.136418 |
| 2813 | 0.121759 | 0.92 | 0.996616 | 0.952768 | 0.131375 |
| 3102 | 0.098721 | 0.84 | 0.985671 | 0.939502 | 0.129771 |
df_errors_top = df_errors.sort_values('error', ascending=False).head(NUMBER_OF_OBSERVATIONS // 3)
indices.extend(df_errors_top.index)
display(df_errors_top)
| LogisticRegression | RandomForestClassifier | XGBClassifier | TabPFNClassifier | error | |
|---|---|---|---|---|---|
| 1299 | 0.955875 | 0.9801 | 0.999413 | 0.992989 | 3.928377 |
| 2914 | 0.954347 | 0.9801 | 0.999683 | 0.992562 | 3.926692 |
| 416 | 0.909995 | 1.0000 | 0.998625 | 0.995772 | 3.904391 |
| 3022 | 0.954380 | 0.9409 | 0.999683 | 0.991938 | 3.886901 |
| 2829 | 0.882356 | 1.0000 | 0.998846 | 0.993345 | 3.874547 |
df_observations = pd.merge(
df.loc[indices],
df_pred.loc[indices],
left_index=True,
right_index=True
).merge(
df_errors[['error']],
left_index=True,
right_index=True
)
df_observations
| V1 | V2 | V3 | V4 | V5 | TARGET | LogisticRegression | RandomForestClassifier | XGBClassifier | TabPFNClassifier | variance | error | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3174 | 1.344327 | 0.263042 | -1.171170 | -0.221492 | -0.312070 | 1 | 0.135482 | 0.05 | 0.012702 | 0.016467 | 0.002442 | 0.021288 |
| 3507 | 2.200125 | -0.391748 | -1.067369 | -0.594245 | -0.223830 | 1 | 0.124309 | 0.03 | 0.000142 | 0.001982 | 0.002560 | 0.016357 |
| 860 | -0.753484 | -0.701102 | 0.926483 | 0.871314 | -0.136583 | 1 | 0.596135 | 0.55 | 0.613609 | 0.859307 | 0.014486 | 0.534699 |
| 1294 | -0.596311 | -0.477304 | 1.682110 | 0.493600 | -1.235860 | 2 | 0.516786 | 0.88 | 0.989788 | 0.874462 | 0.031811 | 0.263760 |
| 1130 | 0.221449 | 1.430543 | -1.264106 | 0.034341 | 1.635539 | 1 | 0.184514 | 0.38 | 0.147287 | 0.675601 | 0.043839 | 0.656576 |
| 1552 | -0.553795 | 0.459280 | 1.619252 | -1.268399 | -0.136583 | 1 | 0.096915 | 0.97 | 0.997193 | 0.981797 | 0.147307 | 2.908613 |
| 2392 | -0.784762 | -1.089178 | -0.321573 | -1.497742 | -0.136583 | 1 | 0.100094 | 0.97 | 0.991902 | 0.965859 | 0.143924 | 2.867673 |
| 157 | 0.574505 | -0.031709 | -0.358881 | -0.029576 | -0.815318 | 1 | 0.112273 | 0.91 | 0.995893 | 0.980029 | 0.136418 | 2.792966 |
| 2813 | 0.153374 | 1.300046 | -1.359490 | -0.730233 | -0.312189 | 2 | 0.121759 | 0.92 | 0.996616 | 0.952768 | 0.131375 | 0.779950 |
| 3102 | -0.261839 | -0.824677 | 1.649498 | 0.906567 | 0.830566 | 2 | 0.098721 | 0.84 | 0.985671 | 0.939502 | 0.129771 | 0.841770 |
| 1299 | -0.236732 | -0.594257 | 1.544963 | -1.186315 | -0.536999 | 2 | 0.022311 | 0.01 | 0.000294 | 0.003512 | 0.000071 | 3.928377 |
| 2914 | -0.777622 | -0.295839 | 0.262987 | -0.118778 | -0.136583 | 1 | 0.023093 | 0.01 | 0.000159 | 0.003726 | 0.000076 | 3.926692 |
| 416 | -0.847626 | -1.055810 | -0.289216 | -1.872634 | -0.136583 | 1 | 0.046064 | 0.00 | 0.000688 | 0.002116 | 0.000382 | 3.904391 |
| 3022 | -0.598224 | -0.375108 | 0.942325 | -1.732816 | 2.614465 | 2 | 0.023076 | 0.03 | 0.000159 | 0.004039 | 0.000157 | 3.886901 |
| 2829 | -0.687936 | -0.938647 | 0.932945 | 1.108643 | 1.099898 | 1 | 0.060662 | 0.00 | 0.000577 | 0.003333 | 0.000662 | 3.874547 |
print(
'Class imbalance in the sample: {y1:.2f}%, the general imbalance is {y2:.2f}%.'.format(
y1=100 * y[indices].mean(),
y2=100 * y.mean()
)
)
Class imbalance in the sample: 33.33%, the general imbalance is 29.35%.
lr_explainer = dx.Explainer(logistic_regression, X, y)
display(lr_explainer.model_performance(cutoff=y.mean()))
rf_explainer = dx.Explainer(random_forest, X, y)
display(rf_explainer.model_performance(cutoff=y.mean()))
xgb_explainer = dx.Explainer(xgb, X, y)
display(xgb_explainer.model_performance(cutoff=y.mean()))
pfn_explainer = dx.Explainer(tab_pfn, X, y)
display(pfn_explainer.model_performance(cutoff=y.mean()))
Preparation of a new explainer is initiated -> data : 5404 rows 5 cols -> target variable : 5404 values -> model_class : sklearn.linear_model._logistic.LogisticRegression (default) -> label : Not specified, model's class short name will be used. (default) -> predict function : <function yhat_proba_default at 0x7f80c99731a0> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 0.0126, mean = 0.306, max = 0.876 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.876, mean = -0.0122, max = 0.978 -> model_info : package sklearn A new explainer has been created!
X does not have valid feature names, but LogisticRegression was fitted with feature names
| recall | precision | f1 | accuracy | auc | |
|---|---|---|---|---|---|
| LogisticRegression | 0.773014 | 0.531657 | 0.63001 | 0.733531 | 0.808701 |
Preparation of a new explainer is initiated -> data : 5404 rows 5 cols -> target variable : 5404 values -> model_class : sklearn.ensemble._forest.RandomForestClassifier (default) -> label : Not specified, model's class short name will be used. (default) -> predict function : <function yhat_proba_default at 0x7f80c99731a0> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray.
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
-> predicted values : min = 0.0, mean = 0.307, max = 1.0 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.98, mean = -0.014, max = 1.0 -> model_info : package sklearn A new explainer has been created!
| recall | precision | f1 | accuracy | auc | |
|---|---|---|---|---|---|
| RandomForestClassifier | 0.940731 | 0.708116 | 0.808015 | 0.868801 | 0.955736 |
Preparation of a new explainer is initiated -> data : 5404 rows 5 cols -> target variable : 5404 values -> model_class : xgboost.sklearn.XGBClassifier (default) -> label : Not specified, model's class short name will be used. (default) -> predict function : <function yhat_proba_default at 0x7f80c99731a0> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 2.46e-05, mean = 0.3, max = 1.0 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.999, mean = -0.00673, max = 1.0 -> model_info : package xgboost A new explainer has been created!
| recall | precision | f1 | accuracy | auc | |
|---|---|---|---|---|---|
| XGBClassifier | 0.883985 | 0.767798 | 0.821805 | 0.887491 | 0.947728 |
Preparation of a new explainer is initiated -> data : 5404 rows 5 cols -> target variable : 5404 values -> model_class : tabpfn.scripts.transformer_prediction_interface.TabPFNClassifier (default) -> label : Not specified, model's class short name will be used. (default) -> predict function : <function yhat_proba_default at 0x7f80c99731a0> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 0.00012, mean = 0.31, max = 0.998 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.99, mean = -0.0162, max = 0.998 -> model_info : package tabpfn A new explainer has been created!
| recall | precision | f1 | accuracy | auc | |
|---|---|---|---|---|---|
| TabPFNClassifier | 0.904161 | 0.697471 | 0.787479 | 0.856773 | 0.943306 |
shap_lr_explainer = shap.Explainer(logistic_regression, X_train)
shap_lr_values = shap_lr_explainer(X)
shap_rf_explainer = shap.Explainer(random_forest, X_train)
shap_rf_values = shap_rf_explainer(X)
shap_xgb_explainer = shap.Explainer(xgb, X_train)
shap_xgb_values = shap_xgb_explainer(X)
def sort_explainables(explanation, columns=df.columns):
values = next(iter(explanation.as_map().values()))
order = dict([(key, index) for index, (key, value) in enumerate(values)])
dictionary = dict(sorted([
(columns[key], value)
for key, value in values
]))
return order, dictionary
def get_sorted_explainables(explanation):
order, dictionary = sort_explainables(explanation)
result = explanation.result
return result.loc[result.index.map(order)].set_index('variable')
def plot_explanation(
explainer,
observation,
draw_plot: bool = True,
num_samples: int = NUM_SAMPLES
):
explanation = explainer.predict_surrogate(observation, num_samples=num_samples)
result = explanation.result
if draw_plot:
result = get_sorted_explainables(explanation).iloc[::-1]
colors = (result['effect'] > 0).map({False: 'red', True: 'green'})
# explanation.plot method is not easily customizable
plt.figure(figsize=FIGSIZE)
plt.gca().yaxis.set_label_position('right')
plt.gca().yaxis.tick_right()
plt.barh(result.index, result['effect'], color=colors)
plt.xlim([-1.0, 1.0])
plt.show()
def show_explainables(explainer, X=X, indices=indices, num_samples: int = NUM_SAMPLES):
seed()
for index in indices:
observation = X.loc[index]
display(df.loc[[index]])
plot_explanation(explainer, observation, num_samples=num_samples)
variances = []
for index in tqdm(sample_indices):
observation = X.loc[index]
results = []
for i in range(25):
explanation = lr_explainer.predict_surrogate(observation, num_samples=NUM_SAMPLES)
result = get_sorted_explainables(explanation).reset_index(drop=True)
results.append(result)
variances.append(np.var(pd.concat(results, axis=1), axis=1))
np.sqrt(np.sum(pd.concat(variances, axis=1), axis=1) / len(variances))
0 0.010810 1 0.011232 2 0.011175 3 0.010665 4 0.011582 dtype: float64
show_explainables(lr_explainer)
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 3174 | 1.344327 | 0.263042 | -1.17117 | -0.221492 | -0.31207 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 3507 | 2.200125 | -0.391748 | -1.067369 | -0.594245 | -0.22383 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 860 | -0.753484 | -0.701102 | 0.926483 | 0.871314 | -0.136583 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 1294 | -0.596311 | -0.477304 | 1.68211 | 0.4936 | -1.23586 | 2 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 1130 | 0.221449 | 1.430543 | -1.264106 | 0.034341 | 1.635539 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 1552 | -0.553795 | 0.45928 | 1.619252 | -1.268399 | -0.136583 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 2392 | -0.784762 | -1.089178 | -0.321573 | -1.497742 | -0.136583 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 157 | 0.574505 | -0.031709 | -0.358881 | -0.029576 | -0.815318 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 2813 | 0.153374 | 1.300046 | -1.35949 | -0.730233 | -0.312189 | 2 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 3102 | -0.261839 | -0.824677 | 1.649498 | 0.906567 | 0.830566 | 2 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 1299 | -0.236732 | -0.594257 | 1.544963 | -1.186315 | -0.536999 | 2 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 2914 | -0.777622 | -0.295839 | 0.262987 | -0.118778 | -0.136583 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 416 | -0.847626 | -1.05581 | -0.289216 | -1.872634 | -0.136583 | 1 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 3022 | -0.598224 | -0.375108 | 0.942325 | -1.732816 | 2.614465 | 2 |
| V1 | V2 | V3 | V4 | V5 | TARGET | |
|---|---|---|---|---|---|---|
| 2829 | -0.687936 | -0.938647 | 0.932945 | 1.108643 | 1.099898 | 1 |
def get_explainables(explainer, X=X, indices=indices):
seed()
results = []
for i in tqdm(indices):
observation = X.loc[i]
explanation = explainer.predict_surrogate(observation, num_samples=NUM_SAMPLES)
order, dictionary = sort_explainables(explanation)
result = explanation.result
result = result.loc[explanation.result.index.map(order)].set_index('variable').reset_index(drop=True)
results.append(result)
return results
def prepare_results(results, series: pd.Series, names: list[str]):
df_results = pd.DataFrame([
series,
np.mean(abs(pd.concat(results, axis=1)), axis=1)
]).T
df_results.index = df.columns[:NUM_COLUMNS]
df_results = df_results.reset_index()
df_results.columns = names
return df_results.set_index('column')
def calculate_correlation(df):
return df.corr().iloc[0, 1]
lr_results = get_explainables(lr_explainer, indices=sample_indices)
lr_series = pd.Series(logistic_regression.coef_[0])
df_lr = prepare_results(lr_results, lr_series, [
'column', 'logistic regression coefficient', 'LIME mean'
])
df_lr
| logistic regression coefficient | LIME mean | |
|---|---|---|
| column | ||
| V1 | -0.699783 | 0.110007 |
| V2 | -0.436966 | 0.081500 |
| V3 | 0.474551 | 0.090017 |
| V4 | 0.672585 | 0.126876 |
| V5 | 0.195841 | 0.036118 |
print(
'Correlation coefficient: {:.2f}%.'.format(
100 * calculate_correlation(abs(df_lr))
)
)
Correlation coefficient: 97.08%.
show_explainables(rf_explainer)
rf_results = get_explainables(rf_explainer, indices=sample_indices)
rf_series = pd.Series(random_forest.feature_importances_)
df_rf = prepare_results(rf_results, rf_series, [
'column', 'feature importance', 'LIME mean'
])
df_rf
| feature importance | LIME mean | |
|---|---|---|
| column | ||
| V1 | 0.199188 | 0.074524 |
| V2 | 0.162220 | 0.078785 |
| V3 | 0.208704 | 0.079272 |
| V4 | 0.301254 | 0.116559 |
| V5 | 0.128634 | 0.056260 |
print(
'Correlation coefficient: {:.2f}%.'.format(
100 * calculate_correlation(df_rf)
)
)
Correlation coefficient: 95.49%.
show_explainables(xgb_explainer)
xgb_results = get_explainables(xgb_explainer, indices=sample_indices)
xgb_series = pd.Series(xgb.feature_importances_)
df_xgb = prepare_results(xgb_results, xgb_series, [
'column', 'feature importance', 'LIME mean'
])
df_xgb
| feature importance | LIME mean | |
|---|---|---|
| column | ||
| V1 | 0.166167 | 0.147629 |
| V2 | 0.142739 | 0.099043 |
| V3 | 0.151651 | 0.062731 |
| V4 | 0.405486 | 0.150314 |
| V5 | 0.133957 | 0.049092 |
print(
'Correlation coefficient: {:.2f}%.'.format(
100 * calculate_correlation(df_xgb)
)
)
Correlation coefficient: 64.49%.
show_explainables(pfn_explainer, num_samples=5)
print('LogisticRegression')
shap.summary_plot(shap_lr_values, sort=True)
print('RandomForestClassifier')
shap.summary_plot(shap_rf_values[:, :, 0], sort=True)
print('XGBClassifier')
shap.summary_plot(shap_xgb_values, sort=True)
LogisticRegressions
RandomForestClassifier
XGBClassifier
def show_shap_values(shap_values, indices=indices):
    for index in indices:
        display(df.loc[[index]])                   # show the raw observation
        shap.plots.waterfall(shap_values[index])   # its local SHAP explanation
show_shap_values(shap_lr_values)
| | V1 | V2 | V3 | V4 | V5 | TARGET |
|---|---|---|---|---|---|---|
| 3174 | 1.344327 | 0.263042 | -1.17117 | -0.221492 | -0.31207 | 1 |
| 3507 | 2.200125 | -0.391748 | -1.067369 | -0.594245 | -0.22383 | 1 |
| 860 | -0.753484 | -0.701102 | 0.926483 | 0.871314 | -0.136583 | 1 |
| 1294 | -0.596311 | -0.477304 | 1.68211 | 0.4936 | -1.23586 | 2 |
| 1130 | 0.221449 | 1.430543 | -1.264106 | 0.034341 | 1.635539 | 1 |
| 1552 | -0.553795 | 0.45928 | 1.619252 | -1.268399 | -0.136583 | 1 |
| 2392 | -0.784762 | -1.089178 | -0.321573 | -1.497742 | -0.136583 | 1 |
| 157 | 0.574505 | -0.031709 | -0.358881 | -0.029576 | -0.815318 | 1 |
| 2813 | 0.153374 | 1.300046 | -1.35949 | -0.730233 | -0.312189 | 2 |
| 3102 | -0.261839 | -0.824677 | 1.649498 | 0.906567 | 0.830566 | 2 |
| 1299 | -0.236732 | -0.594257 | 1.544963 | -1.186315 | -0.536999 | 2 |
| 2914 | -0.777622 | -0.295839 | 0.262987 | -0.118778 | -0.136583 | 1 |
| 416 | -0.847626 | -1.05581 | -0.289216 | -1.872634 | -0.136583 | 1 |
| 3022 | -0.598224 | -0.375108 | 0.942325 | -1.732816 | 2.614465 | 2 |
| 2829 | -0.687936 | -0.938647 | 0.932945 | 1.108643 | 1.099898 | 1 |
show_shap_values(shap_rf_values[:, :, 0])
| | V1 | V2 | V3 | V4 | V5 | TARGET |
|---|---|---|---|---|---|---|
| 3174 | 1.344327 | 0.263042 | -1.17117 | -0.221492 | -0.31207 | 1 |
| 3507 | 2.200125 | -0.391748 | -1.067369 | -0.594245 | -0.22383 | 1 |
| 860 | -0.753484 | -0.701102 | 0.926483 | 0.871314 | -0.136583 | 1 |
| 1294 | -0.596311 | -0.477304 | 1.68211 | 0.4936 | -1.23586 | 2 |
| 1130 | 0.221449 | 1.430543 | -1.264106 | 0.034341 | 1.635539 | 1 |
| 1552 | -0.553795 | 0.45928 | 1.619252 | -1.268399 | -0.136583 | 1 |
| 2392 | -0.784762 | -1.089178 | -0.321573 | -1.497742 | -0.136583 | 1 |
| 157 | 0.574505 | -0.031709 | -0.358881 | -0.029576 | -0.815318 | 1 |
| 2813 | 0.153374 | 1.300046 | -1.35949 | -0.730233 | -0.312189 | 2 |
| 3102 | -0.261839 | -0.824677 | 1.649498 | 0.906567 | 0.830566 | 2 |
| 1299 | -0.236732 | -0.594257 | 1.544963 | -1.186315 | -0.536999 | 2 |
| 2914 | -0.777622 | -0.295839 | 0.262987 | -0.118778 | -0.136583 | 1 |
| 416 | -0.847626 | -1.05581 | -0.289216 | -1.872634 | -0.136583 | 1 |
| 3022 | -0.598224 | -0.375108 | 0.942325 | -1.732816 | 2.614465 | 2 |
| 2829 | -0.687936 | -0.938647 | 0.932945 | 1.108643 | 1.099898 | 1 |
show_shap_values(shap_xgb_values)
| | V1 | V2 | V3 | V4 | V5 | TARGET |
|---|---|---|---|---|---|---|
| 3174 | 1.344327 | 0.263042 | -1.17117 | -0.221492 | -0.31207 | 1 |
| 3507 | 2.200125 | -0.391748 | -1.067369 | -0.594245 | -0.22383 | 1 |
| 860 | -0.753484 | -0.701102 | 0.926483 | 0.871314 | -0.136583 | 1 |
| 1294 | -0.596311 | -0.477304 | 1.68211 | 0.4936 | -1.23586 | 2 |
| 1130 | 0.221449 | 1.430543 | -1.264106 | 0.034341 | 1.635539 | 1 |
| 1552 | -0.553795 | 0.45928 | 1.619252 | -1.268399 | -0.136583 | 1 |
| 2392 | -0.784762 | -1.089178 | -0.321573 | -1.497742 | -0.136583 | 1 |
| 157 | 0.574505 | -0.031709 | -0.358881 | -0.029576 | -0.815318 | 1 |
| 2813 | 0.153374 | 1.300046 | -1.35949 | -0.730233 | -0.312189 | 2 |
| 3102 | -0.261839 | -0.824677 | 1.649498 | 0.906567 | 0.830566 | 2 |
| 1299 | -0.236732 | -0.594257 | 1.544963 | -1.186315 | -0.536999 | 2 |
| 2914 | -0.777622 | -0.295839 | 0.262987 | -0.118778 | -0.136583 | 1 |
| 416 | -0.847626 | -1.05581 | -0.289216 | -1.872634 | -0.136583 | 1 |
| 3022 | -0.598224 | -0.375108 | 0.942325 | -1.732816 | 2.614465 | 2 |
| 2829 | -0.687936 | -0.938647 | 0.932945 | 1.108643 | 1.099898 | 1 |